import time

import numpy as np
import torch
from torch.utils.data import DataLoader
from mmdet.apis import init_detector, inference_detector

from image_dataset_sota import ImageDataset


def inference_pipeline(
    fp_input: str,
    fp_output: str,
    fp_config: str = "mmdetection/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_16e_o365tococo.py",
    fp_ckpt: str = "co_dino_5scale_swin_large_16e_o365tococo-614254c9.pth",
    n_workers: int = 2,
    pred_score_thr: float = 0.6,
) -> None:
    """Run the detector over a dataset and write per-image people counts as CSV.

    A detector is built with ``init_detector`` from ``fp_config`` / ``fp_ckpt``,
    then every sample yielded by ``ImageDataset(fp_input)`` is run through
    ``inference_detector`` one image at a time. For each image, detections
    with label ``0`` and a score of at least ``pred_score_thr`` are counted
    and one ``qry_id,img_rank,n_people`` row is written to ``fp_output``.
    Progress is printed every 100 images.

    Args:
        fp_input: Passed unchanged to ``ImageDataset``; its exact meaning is
            defined by that class.
        fp_output: Path of the CSV file to write. Any existing file is
            overwritten.
        fp_config: Path of the MMDetection config handed to ``init_detector``.
        fp_ckpt: Path of the checkpoint handed to ``init_detector``.
        n_workers: Number of ``DataLoader`` worker processes.
        pred_score_thr: Minimum score (inclusive) for a detection to count.
    """
    # Prefer the second GPU as before, but fall back to the first one when the
    # host exposes a single device: "cuda:1" does not exist there.
    if torch.cuda.is_available():
        gpu_index = 1 if torch.cuda.device_count() > 1 else 0
        device = torch.device(f"cuda:{gpu_index}")
    else:
        device = torch.device("cpu")

    model = init_detector(fp_config, fp_ckpt, device=device)

    dataset = ImageDataset(fp_input)
    loader = DataLoader(dataset, batch_size=1, num_workers=n_workers)

    # Detections with this label feed the `n_people` column; presumably the
    # COCO "person" class index - confirm when using a non-default config.
    person_label = 0

    # Keep one handle open for the whole run instead of re-opening per image.
    with open(fp_output, "w") as f:
        f.write("qry_id,img_rank,n_people\n")
        # Empty the buffer before the DataLoader starts its worker processes
        # so no pending data is shared with them.
        f.flush()

        st = time.perf_counter()
        for i, (qry_id, img_rank, img_data) in enumerate(loader):
            # Drop the batch axis (batch_size is 1) and move the leading axis
            # to the end; presumably CHW -> HWC, confirm against ImageDataset.
            img_np = np.transpose(img_data.squeeze(0).numpy(), (1, 2, 0))
            results = inference_detector(model, img_np)

            n_ppl = 0
            if "pred_instances" in results:
                pred_instances = results.pred_instances.numpy()
                labels = pred_instances.labels.tolist()
                scores = pred_instances.scores.tolist()
                n_ppl = sum(
                    1
                    for label, score in zip(labels, scores)
                    if label == person_label and score >= pred_score_thr
                )

            f.write(f"{qry_id.numpy()[0]},{img_rank.numpy()[0]},{n_ppl}\n")
            # Flush per row so finished results survive an interrupted run,
            # as they did when the file was re-opened for every image.
            f.flush()

            if (i + 1) % 100 == 0:
                print(
                    f"Processing 100 images took {time.perf_counter()-st}s. {(i + 1)} images have been processed"
                )
                st = time.perf_counter()

if __name__ == "__main__":
    # Script entry point. `inference_pipeline` requires `fp_input` and
    # `fp_output`, so calling it without arguments raised TypeError; they are
    # now taken from the command line.
    import argparse

    parser = argparse.ArgumentParser(
        description="Run the detector over a dataset and write per-image people counts to a CSV file."
    )
    parser.add_argument("fp_input", help="input passed to ImageDataset")
    parser.add_argument("fp_output", help="path of the CSV file to write")
    # SUPPRESS leaves unset options out of the namespace, so the defaults
    # declared on `inference_pipeline` stay the single source of truth.
    parser.add_argument("--fp-config", dest="fp_config", default=argparse.SUPPRESS)
    parser.add_argument("--fp-ckpt", dest="fp_ckpt", default=argparse.SUPPRESS)
    parser.add_argument(
        "--n-workers", dest="n_workers", type=int, default=argparse.SUPPRESS
    )
    parser.add_argument(
        "--pred-score-thr",
        dest="pred_score_thr",
        type=float,
        default=argparse.SUPPRESS,
    )

    inference_pipeline(**vars(parser.parse_args()))
